{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s_qNSzzyaCbD"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "jmjh290raIky"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J0Qjg6vuaHNt"
      },
      "source": [
        "# 基于注意力的神经机器翻译"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AOpGoE2T-YXS"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/text/nmt_with_attention\">\n",
        "    <img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />\n",
        "    在 TensorFlow.org 上查看</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/text/nmt_with_attention.ipynb\">\n",
        "    <img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />\n",
        "    在 Google Colab 运行</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/text/nmt_with_attention.ipynb\">\n",
        "    <img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />\n",
        "    在 GitHub 上查看源代码</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/tutorials/text/nmt_with_attention.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />下载此 notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8dEwzVWg0f-E"
      },
      "source": [
        "Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为， 所以无法保证它们是最准确的，并且反映了最新的\n",
        "[官方英文文档](https://tensorflow.google.cn/?hl=en)。如果您有改进此翻译的建议， 请提交 pull request 到\n",
        "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文，请加入\n",
        "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CiwtNgENbx2g"
      },
      "source": [
        "此笔记本训练一个将西班牙语翻译为英语的序列到序列（sequence to sequence，简写为 seq2seq）模型。此例子难度较高，需要对序列到序列模型的知识有一定了解。\n",
        "\n",
        "训练完此笔记本中的模型后，你将能够输入一个西班牙语句子，例如 *\"¿todavia estan en casa?\"*，并返回其英语翻译 *\"are you still at home?\"*\n",
        "\n",
        "对于一个简单的例子来说，翻译质量令人满意。但是更有趣的可能是生成的注意力图：它显示在翻译过程中，输入句子的哪些部分受到了模型的注意。\n",
        "\n",
        "<img src=\"https://tensorflow.google.cn/images/spanish-english.png\" alt=\"spanish-english attention plot\">\n",
        "\n",
        "请注意：运行这个例子用一个 P100 GPU 需要花大约 10 分钟。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tnxXKDjq3jEL"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.ticker as ticker\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "import unicodedata\n",
        "import re\n",
        "import numpy as np\n",
        "import os\n",
        "import io\n",
        "import time"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wfodePkj3jEa"
      },
      "source": [
        "## 下载和准备数据集\n",
        "\n",
        "我们将使用 http://www.manythings.org/anki/ 提供的一个语言数据集。这个数据集包含如下格式的语言翻译对：\n",
        "\n",
        "```\n",
        "May I borrow this book?\t¿Puedo tomar prestado este libro?\n",
        "```\n",
        "\n",
        "这个数据集中有很多种语言可供选择。我们将使用英语 - 西班牙语数据集。为方便使用，我们在谷歌云上提供了此数据集的一份副本。但是你也可以自己下载副本。下载完数据集后，我们将采取下列步骤准备数据：\n",
        "\n",
        "1. 给每个句子添加一个 *开始* 和一个 *结束* 标记（token）。\n",
        "2. 删除特殊字符以清理句子。\n",
        "3. 创建一个单词索引和一个反向单词索引（即一个从单词映射至 id 的词典和一个从 id 映射至单词的词典）。\n",
        "4. 将每个句子填充（pad）到最大长度。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kRVATYOgJs1b"
      },
      "outputs": [],
      "source": [
        "# 下载文件\n",
        "path_to_zip = tf.keras.utils.get_file(\n",
        "    'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',\n",
        "    extract=True)\n",
        "\n",
        "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rd0jw-eC3jEh"
      },
      "outputs": [],
      "source": [
        "# 将 unicode 文件转换为 ascii\n",
        "def unicode_to_ascii(s):\n",
        "    return ''.join(c for c in unicodedata.normalize('NFD', s)\n",
        "        if unicodedata.category(c) != 'Mn')\n",
        "\n",
        "\n",
        "def preprocess_sentence(w):\n",
        "    w = unicode_to_ascii(w.lower().strip())\n",
        "\n",
        "    # 在单词与跟在其后的标点符号之间插入一个空格\n",
        "    # 例如： \"he is a boy.\" => \"he is a boy .\"\n",
        "    # 参考：https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n",
        "    w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
        "    w = re.sub(r'[\" \"]+', \" \", w)\n",
        "\n",
        "    # 除了 (a-z, A-Z, \".\", \"?\", \"!\", \",\")，将所有字符替换为空格\n",
        "    w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n",
        "\n",
        "    w = w.rstrip().strip()\n",
        "\n",
        "    # 给句子加上开始和结束标记\n",
        "    # 以便模型知道何时开始和结束预测\n",
        "    w = '<start> ' + w + ' <end>'\n",
        "    return w"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "opI2GzOt479E"
      },
      "outputs": [],
      "source": [
        "en_sentence = u\"May I borrow this book?\"\n",
        "sp_sentence = u\"¿Puedo tomar prestado este libro?\"\n",
        "print(preprocess_sentence(en_sentence))\n",
        "print(preprocess_sentence(sp_sentence).encode('utf-8'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OHn4Dct23jEm"
      },
      "outputs": [],
      "source": [
        "# 1. 去除重音符号\n",
        "# 2. 清理句子\n",
        "# 3. 返回这样格式的单词对：[ENGLISH, SPANISH]\n",
        "def create_dataset(path, num_examples):\n",
        "    lines = io.open(path, encoding='UTF-8').read().strip().split('\\n')\n",
        "\n",
        "    word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')]  for l in lines[:num_examples]]\n",
        "\n",
        "    return zip(*word_pairs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cTbSbBz55QtF"
      },
      "outputs": [],
      "source": [
        "en, sp = create_dataset(path_to_file, None)\n",
        "print(en[-1])\n",
        "print(sp[-1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OmMZQpdO60dt"
      },
      "outputs": [],
      "source": [
        "def max_length(tensor):\n",
        "    return max(len(t) for t in tensor)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bIOn8RCNDJXG"
      },
      "outputs": [],
      "source": [
        "def tokenize(lang):\n",
        "  lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(\n",
        "      filters='')\n",
        "  lang_tokenizer.fit_on_texts(lang)\n",
        "\n",
        "  tensor = lang_tokenizer.texts_to_sequences(lang)\n",
        "\n",
        "  tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,\n",
        "                                                         padding='post')\n",
        "\n",
        "  return tensor, lang_tokenizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eAY9k49G3jE_"
      },
      "outputs": [],
      "source": [
        "def load_dataset(path, num_examples=None):\n",
        "    # 创建清理过的输入输出对\n",
        "    targ_lang, inp_lang = create_dataset(path, num_examples)\n",
        "\n",
        "    input_tensor, inp_lang_tokenizer = tokenize(inp_lang)\n",
        "    target_tensor, targ_lang_tokenizer = tokenize(targ_lang)\n",
        "\n",
        "    return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GOi42V79Ydlr"
      },
      "source": [
        "### 限制数据集的大小以加快实验速度（可选）\n",
        "\n",
        "在超过 10 万个句子的完整数据集上训练需要很长时间。为了更快地训练，我们可以将数据集的大小限制为 3 万个句子（当然，翻译质量也会随着数据的减少而降低）："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cnxC7q-j3jFD"
      },
      "outputs": [],
      "source": [
        "# 尝试实验不同大小的数据集\n",
        "num_examples = 30000\n",
        "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)\n",
        "\n",
        "# 计算目标张量的最大长度 （max_length）\n",
        "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4QILQkOs3jFG"
      },
      "outputs": [],
      "source": [
        "# 采用 80 - 20 的比例切分训练集和验证集\n",
        "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n",
        "\n",
        "# 显示长度\n",
        "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lJPmLZGMeD5q"
      },
      "outputs": [],
      "source": [
        "def convert(lang, tensor):\n",
        "  for t in tensor:\n",
        "    if t!=0:\n",
        "      print (\"%d ----> %s\" % (t, lang.index_word[t]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VXukARTDd7MT"
      },
      "outputs": [],
      "source": [
        "print (\"Input Language; index to word mapping\")\n",
        "convert(inp_lang, input_tensor_train[0])\n",
        "print ()\n",
        "print (\"Target Language; index to word mapping\")\n",
        "convert(targ_lang, target_tensor_train[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rgCLkfv5uO3d"
      },
      "source": [
        "### 创建一个 tf.data 数据集"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TqHsArVZ3jFS"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = len(input_tensor_train)\n",
        "BATCH_SIZE = 64\n",
        "steps_per_epoch = len(input_tensor_train)//BATCH_SIZE\n",
        "embedding_dim = 256\n",
        "units = 1024\n",
        "vocab_inp_size = len(inp_lang.word_index)+1\n",
        "vocab_tar_size = len(targ_lang.word_index)+1\n",
        "\n",
        "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
        "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qc6-NK1GtWQt"
      },
      "outputs": [],
      "source": [
        "example_input_batch, example_target_batch = next(iter(dataset))\n",
        "example_input_batch.shape, example_target_batch.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TNfHIF71ulLu"
      },
      "source": [
        "## 编写编码器 （encoder） 和解码器 （decoder） 模型\n",
        "\n",
        "实现一个基于注意力的编码器 - 解码器模型。关于这种模型，你可以阅读 TensorFlow 的 [神经机器翻译 (序列到序列) 教程](https://github.com/tensorflow/nmt)。本示例采用一组更新的 API。此笔记本实现了上述序列到序列教程中的 [注意力方程式](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism)。下图显示了注意力机制为每个输入单词分配一个权重，然后解码器将这个权重用于预测句子中的下一个单词。下图和公式是 [Luong 的论文](https://arxiv.org/abs/1508.04025v5)中注意力机制的一个例子。\n",
        "\n",
        "<img src=\"https://tensorflow.google.cn/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\">\n",
        "\n",
        "输入经过编码器模型，编码器模型为我们提供形状为 *(批大小，最大长度，隐藏层大小)* 的编码器输出和形状为 *(批大小，隐藏层大小)* 的编码器隐藏层状态。\n",
        "\n",
        "下面是所实现的方程式：\n",
        "\n",
        "<img src=\"https://tensorflow.google.cn/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\">\n",
        "<img src=\"https://tensorflow.google.cn/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\">\n",
        "\n",
        "本教程的编码器采用 [Bahdanau 注意力](https://arxiv.org/pdf/1409.0473.pdf)。在用简化形式编写之前，让我们先决定符号：\n",
        "\n",
        "* FC = 完全连接（密集）层\n",
        "* EO = 编码器输出\n",
        "* H = 隐藏层状态\n",
        "* X = 解码器输入\n",
        "\n",
        "以及伪代码：\n",
        "\n",
        "* `score = FC(tanh(FC(EO) + FC(H)))`\n",
        "* `attention weights = softmax(score, axis = 1)`。 Softmax 默认被应用于最后一个轴，但是这里我们想将它应用于 *第一个轴*, 因为分数 （score） 的形状是 *(批大小，最大长度，隐藏层大小)*。最大长度 （`max_length`） 是我们的输入的长度。因为我们想为每个输入分配一个权重，所以 softmax 应该用在这个轴上。\n",
        "* `context vector = sum(attention weights * EO, axis = 1)`。选择第一个轴的原因同上。\n",
        "* `embedding output` = 解码器输入 X 通过一个嵌入层。\n",
        "* `merged vector = concat(embedding output, context vector)`\n",
        "* 此合并后的向量随后被传送到 GRU\n",
        "\n",
        "每个步骤中所有向量的形状已在代码的注释中阐明："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nZ2rI24i3jFg"
      },
      "outputs": [],
      "source": [
        "class Encoder(tf.keras.Model):\n",
        "  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
        "    super(Encoder, self).__init__()\n",
        "    self.batch_sz = batch_sz\n",
        "    self.enc_units = enc_units\n",
        "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "    self.gru = tf.keras.layers.GRU(self.enc_units,\n",
        "                                   return_sequences=True,\n",
        "                                   return_state=True,\n",
        "                                   recurrent_initializer='glorot_uniform')\n",
        "\n",
        "  def call(self, x, hidden):\n",
        "    x = self.embedding(x)\n",
        "    output, state = self.gru(x, initial_state = hidden)\n",
        "    return output, state\n",
        "\n",
        "  def initialize_hidden_state(self):\n",
        "    return tf.zeros((self.batch_sz, self.enc_units))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "60gSVh05Jl6l"
      },
      "outputs": [],
      "source": [
        "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
        "\n",
        "# 样本输入\n",
        "sample_hidden = encoder.initialize_hidden_state()\n",
        "sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)\n",
        "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n",
        "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "umohpBN2OM94"
      },
      "outputs": [],
      "source": [
        "class BahdanauAttention(tf.keras.layers.Layer):\n",
        "  def __init__(self, units):\n",
        "    super(BahdanauAttention, self).__init__()\n",
        "    self.W1 = tf.keras.layers.Dense(units)\n",
        "    self.W2 = tf.keras.layers.Dense(units)\n",
        "    self.V = tf.keras.layers.Dense(1)\n",
        "\n",
        "  def call(self, query, values):\n",
        "    # 隐藏层的形状 == （批大小，隐藏层大小）\n",
        "    # hidden_with_time_axis 的形状 == （批大小，1，隐藏层大小）\n",
        "    # 这样做是为了执行加法以计算分数  \n",
        "    hidden_with_time_axis = tf.expand_dims(query, 1)\n",
        "\n",
        "    # 分数的形状 == （批大小，最大长度，1）\n",
        "    # 我们在最后一个轴上得到 1， 因为我们把分数应用于 self.V\n",
        "    # 在应用 self.V 之前，张量的形状是（批大小，最大长度，单位）\n",
        "    score = self.V(tf.nn.tanh(\n",
        "        self.W1(values) + self.W2(hidden_with_time_axis)))\n",
        "\n",
        "    # 注意力权重 （attention_weights） 的形状 == （批大小，最大长度，1）\n",
        "    attention_weights = tf.nn.softmax(score, axis=1)\n",
        "\n",
        "    # 上下文向量 （context_vector） 求和之后的形状 == （批大小，隐藏层大小）\n",
        "    context_vector = attention_weights * values\n",
        "    context_vector = tf.reduce_sum(context_vector, axis=1)\n",
        "\n",
        "    return context_vector, attention_weights"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k534zTHiDjQU"
      },
      "outputs": [],
      "source": [
        "attention_layer = BahdanauAttention(10)\n",
        "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n",
        "\n",
        "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n",
        "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yJ_B3mhW3jFk"
      },
      "outputs": [],
      "source": [
        "class Decoder(tf.keras.Model):\n",
        "  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
        "    super(Decoder, self).__init__()\n",
        "    self.batch_sz = batch_sz\n",
        "    self.dec_units = dec_units\n",
        "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "    self.gru = tf.keras.layers.GRU(self.dec_units,\n",
        "                                   return_sequences=True,\n",
        "                                   return_state=True,\n",
        "                                   recurrent_initializer='glorot_uniform')\n",
        "    self.fc = tf.keras.layers.Dense(vocab_size)\n",
        "\n",
        "    # 用于注意力\n",
        "    self.attention = BahdanauAttention(self.dec_units)\n",
        "\n",
        "  def call(self, x, hidden, enc_output):\n",
        "    # 编码器输出 （enc_output） 的形状 == （批大小，最大长度，隐藏层大小）\n",
        "    context_vector, attention_weights = self.attention(hidden, enc_output)\n",
        "\n",
        "    # x 在通过嵌入层后的形状 == （批大小，1，嵌入维度）\n",
        "    x = self.embedding(x)\n",
        "\n",
        "    # x 在拼接 （concatenation） 后的形状 == （批大小，1，嵌入维度 + 隐藏层大小）\n",
        "    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
        "\n",
        "    # 将合并后的向量传送到 GRU\n",
        "    output, state = self.gru(x)\n",
        "\n",
        "    # 输出的形状 == （批大小 * 1，隐藏层大小）\n",
        "    output = tf.reshape(output, (-1, output.shape[2]))\n",
        "\n",
        "    # 输出的形状 == （批大小，vocab）\n",
        "    x = self.fc(output)\n",
        "\n",
        "    return x, state, attention_weights"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P5UY8wko3jFp"
      },
      "outputs": [],
      "source": [
        "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n",
        "\n",
        "sample_decoder_output, _, _ = decoder(tf.random.uniform((64, 1)),\n",
        "                                      sample_hidden, sample_output)\n",
        "\n",
        "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_ch_71VbIRfK"
      },
      "source": [
        "## 定义优化器和损失函数"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WmTHr5iV3jFr"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.keras.optimizers.Adam()\n",
        "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "    from_logits=True, reduction='none')\n",
        "\n",
        "def loss_function(real, pred):\n",
        "  mask = tf.math.logical_not(tf.math.equal(real, 0))\n",
        "  loss_ = loss_object(real, pred)\n",
        "\n",
        "  mask = tf.cast(mask, dtype=loss_.dtype)\n",
        "  loss_ *= mask\n",
        "\n",
        "  return tf.reduce_mean(loss_)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DMVWzzsfNl4e"
      },
      "source": [
        "## 检查点（基于对象保存）"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zj8bXQTgNwrF"
      },
      "outputs": [],
      "source": [
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
        "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
        "                                 encoder=encoder,\n",
        "                                 decoder=decoder)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hpObfY22IddU"
      },
      "source": [
        "## 训练\n",
        "\n",
        "1. 将 *输入* 传送至 *编码器*，编码器返回 *编码器输出* 和 *编码器隐藏层状态*。\n",
        "2. 将编码器输出、编码器隐藏层状态和解码器输入（即 *开始标记*）传送至解码器。\n",
        "3. 解码器返回 *预测* 和 *解码器隐藏层状态*。\n",
        "4. 解码器隐藏层状态被传送回模型，预测被用于计算损失。\n",
        "5. 使用 *教师强制 （teacher forcing）* 决定解码器的下一个输入。\n",
        "6. *教师强制* 是将 *目标词* 作为 *下一个输入* 传送至解码器的技术。\n",
        "7. 最后一步是计算梯度，并将其应用于优化器和反向传播。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sC9ArXSsVfqn"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(inp, targ, enc_hidden):\n",
        "  loss = 0\n",
        "\n",
        "  with tf.GradientTape() as tape:\n",
        "    enc_output, enc_hidden = encoder(inp, enc_hidden)\n",
        "\n",
        "    dec_hidden = enc_hidden\n",
        "\n",
        "    dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)\n",
        "\n",
        "    # 教师强制 - 将目标词作为下一个输入\n",
        "    for t in range(1, targ.shape[1]):\n",
        "      # 将编码器输出 （enc_output） 传送至解码器\n",
        "      predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n",
        "\n",
        "      loss += loss_function(targ[:, t], predictions)\n",
        "\n",
        "      # 使用教师强制\n",
        "      dec_input = tf.expand_dims(targ[:, t], 1)\n",
        "\n",
        "  batch_loss = (loss / int(targ.shape[1]))\n",
        "\n",
        "  variables = encoder.trainable_variables + decoder.trainable_variables\n",
        "\n",
        "  gradients = tape.gradient(loss, variables)\n",
        "\n",
        "  optimizer.apply_gradients(zip(gradients, variables))\n",
        "\n",
        "  return batch_loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ddefjBMa3jF0"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 10\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "  start = time.time()\n",
        "\n",
        "  enc_hidden = encoder.initialize_hidden_state()\n",
        "  total_loss = 0\n",
        "\n",
        "  for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n",
        "    batch_loss = train_step(inp, targ, enc_hidden)\n",
        "    total_loss += batch_loss\n",
        "\n",
        "    if batch % 100 == 0:\n",
        "        print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
        "                                                     batch,\n",
        "                                                     batch_loss.numpy()))\n",
        "  # 每 2 个周期（epoch），保存（检查点）一次模型\n",
        "  if (epoch + 1) % 2 == 0:\n",
        "    checkpoint.save(file_prefix = checkpoint_prefix)\n",
        "\n",
        "  print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
        "                                      total_loss / steps_per_epoch))\n",
        "  print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mU3Ce8M6I3rz"
      },
      "source": [
        "## 翻译\n",
        "\n",
        "* 评估函数类似于训练循环，不同之处在于在这里我们不使用 *教师强制*。每个时间步的解码器输入是其先前的预测、隐藏层状态和编码器输出。\n",
        "* 当模型预测 *结束标记* 时停止预测。\n",
        "* 存储 *每个时间步的注意力权重*。\n",
        "\n",
        "请注意：对于一个输入，编码器输出仅计算一次。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EbQpyYs13jF_"
      },
      "outputs": [],
      "source": [
        "def evaluate(sentence):\n",
        "    attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
        "\n",
        "    sentence = preprocess_sentence(sentence)\n",
        "\n",
        "    inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n",
        "    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n",
        "                                                           maxlen=max_length_inp,\n",
        "                                                           padding='post')\n",
        "    inputs = tf.convert_to_tensor(inputs)\n",
        "\n",
        "    result = ''\n",
        "\n",
        "    hidden = [tf.zeros((1, units))]\n",
        "    enc_out, enc_hidden = encoder(inputs, hidden)\n",
        "\n",
        "    dec_hidden = enc_hidden\n",
        "    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)\n",
        "\n",
        "    for t in range(max_length_targ):\n",
        "        predictions, dec_hidden, attention_weights = decoder(dec_input,\n",
        "                                                             dec_hidden,\n",
        "                                                             enc_out)\n",
        "\n",
        "        # 存储注意力权重以便后面制图\n",
        "        attention_weights = tf.reshape(attention_weights, (-1, ))\n",
        "        attention_plot[t] = attention_weights.numpy()\n",
        "\n",
        "        predicted_id = tf.argmax(predictions[0]).numpy()\n",
        "\n",
        "        result += targ_lang.index_word[predicted_id] + ' '\n",
        "\n",
        "        if targ_lang.index_word[predicted_id] == '<end>':\n",
        "            return result, sentence, attention_plot\n",
        "\n",
        "        # 预测的 ID 被输送回模型\n",
        "        dec_input = tf.expand_dims([predicted_id], 0)\n",
        "\n",
        "    return result, sentence, attention_plot"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s5hQWlbN3jGF"
      },
      "outputs": [],
      "source": [
        "# 注意力权重制图函数\n",
        "def plot_attention(attention, sentence, predicted_sentence):\n",
        "    fig = plt.figure(figsize=(10,10))\n",
        "    ax = fig.add_subplot(1, 1, 1)\n",
        "    ax.matshow(attention, cmap='viridis')\n",
        "\n",
        "    fontdict = {'fontsize': 14}\n",
        "\n",
        "    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n",
        "    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n",
        "\n",
        "    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
        "    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
        "\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sl9zUHzg3jGI"
      },
      "outputs": [],
      "source": [
        "def translate(sentence):\n",
        "    result, sentence, attention_plot = evaluate(sentence)\n",
        "\n",
        "    print('Input: %s' % (sentence))\n",
        "    print('Predicted translation: {}'.format(result))\n",
        "\n",
        "    attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n",
        "    plot_attention(attention_plot, sentence.split(' '), result.split(' '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n250XbnjOaqP"
      },
      "source": [
        "## 恢复最新的检查点并验证"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UJpT9D5_OgP6"
      },
      "outputs": [],
      "source": [
        "# 恢复检查点目录 （checkpoint_dir） 中最新的检查点\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WrAM0FDomq3E"
      },
      "outputs": [],
      "source": [
        "translate(u'hace mucho frio aqui.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zSx2iM36EZQZ"
      },
      "outputs": [],
      "source": [
        "translate(u'esta es mi vida.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A3LLCx3ZE0Ls"
      },
      "outputs": [],
      "source": [
        "translate(u'¿todavia estan en casa?')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DUQVLVqUE1YW"
      },
      "outputs": [],
      "source": [
        "# 错误的翻译\n",
        "translate(u'trata de averiguarlo.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RTe5P5ioMJwN"
      },
      "source": [
        "## 下一步\n",
        "\n",
        "* [下载一个不同的数据集](http://www.manythings.org/anki/)实验翻译，例如英语到德语或者英语到法语。\n",
        "* 实验在更大的数据集上训练，或者增加训练周期。"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "nmt_with_attention.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
